# NOTE(review): the `EPSG` dataset loaded here is not referenced in the code
# visible in this script section -- confirm it is needed before removing.
data("EPSG")

# load in prepared PA map ------
# `map` is the prepared Pennsylvania map on which every sampler and summary
# below is run.
map <- read_rds(here("data", "PA", "pa_map.rds"))

# run SMC and MCMC -----

## SMC  ---------
# SMC sample sizes: log-spaced from 10^2 to 10^3.6 and rounded to multiples of
# 50, i.e. 100, 250, 650, 1600, 4000.
N_seq <- 50 * round(10^seq(2, 3.6, by=0.4) / 50)
# The plans from the run at this size are kept as `plans_smc` for the analyses
# below. Deriving it from N_seq (rather than a literal 4000) guarantees that
# the kept size is one that is actually run, even if the sweep is edited.
N_keep <- max(N_seq)

# One SMC fit (4 independent runs, labelled by `chain`) per sample size. Each
# element is a list holding `res`, a one-row tibble of efficiency/accuracy
# summaries, and `plans`, the sampled plans for N_keep only (NULL otherwise so
# the smaller fits can be freed). Returning the plans avoids assigning a
# global with `<<-` from inside the mapped function.
smc_runs <- lapply(N_seq, function(N) {
    cat("Running SMC for N =", N, "\n")

    pl_smc <- redist_smc(map, N, counties=county, runs=4L, ncores=2,
                         seq_alpha=0.7, pop_temper=0.01)

    # One row per (run, draw): compactness (comp_frac_kept), county splits,
    # and expected Democratic seats, i.e. each district's Democratic share
    # passed through pnorm(., 0.5, 0.035) and summed over districts.
    pl_smc_s <- pl_smc %>%
        subset_sampled() %>%
        mutate(comp = comp_frac_kept(., map),
               spl = county_splits(map, county),
               dem = group_frac(map, NDV, NDV+NRV)) %>%
        group_by(chain, draw) %>%
        summarize(comp = comp[1],
                  spl = spl[1],
                  edem = sum(pnorm(dem, 0.5, 0.035)),
                  .groups="drop")

    diagn <- attr(pl_smc, "diagnostics")
    vi <- plans_diversity(pl_smc, 150L)

    # Per-run means; their spread across runs is reported as sd_*.
    ests_comp <- tapply(pl_smc_s$comp, pl_smc_s$chain, mean)
    ests_spl <- tapply(pl_smc_s$spl, pl_smc_s$chain, mean)
    ests_edem <- tapply(pl_smc_s$edem, pl_smc_s$chain, mean)

    res_row <- tibble(alg="smc", iter=N,
                      time = mean(map_dbl(diagn, ~ .$runtime)),
                      n_eff = mean(map_dbl(diagn, ~ .$n_eff)),
                      vi_q10 = quantile(vi, 0.1),
                      vi_q90 = quantile(vi, 0.9),
                      rhat_comp = redist:::diag_rhat(pl_smc_s$comp, pl_smc_s$chain),
                      rhat_spl = redist:::diag_rhat(pl_smc_s$spl, pl_smc_s$chain),
                      rhat_edem = redist:::diag_rhat(pl_smc_s$edem, pl_smc_s$chain),
                      est_comp = mean(ests_comp), sd_comp = sd(ests_comp),
                      est_spl = mean(ests_spl), sd_spl = sd(ests_spl),
                      est_edem = mean(ests_edem), sd_edem = sd(ests_edem))

    list(res = res_row, plans = if (N == N_keep) pl_smc)
})

# Rows stay in N_seq (increasing) order, as before.
res_smc <- bind_rows(lapply(smc_runs, \(run) run$res))
plans_smc <- smc_runs[[which.max(N_seq)]]$plans
rm(smc_runs)

# Per-draw summary statistics for the retained SMC sample, built with the same
# pipeline as in the sweep: compactness, county splits, and expected
# Democratic seats, one row per (chain, draw). The result is a plain tibble
# with the "plans" attribute removed.
plans_smc_s <- plans_smc %>%
    subset_sampled() %>%
    mutate(comp = comp_frac_kept(., map),
           spl = county_splits(map, county),
           dem = group_frac(map, NDV, NDV + NRV)) %>%
    group_by(chain, draw) %>%
    summarize(comp = head(comp, 1),
              spl = head(spl, 1),
              edem = sum(pnorm(dem, mean=0.5, sd=0.035)),
              .groups="drop") %>%
    as_tibble()
attr(plans_smc_s, "plans") <- NULL

## MCMC  ---------
# Four merge-split chains of 8000 + 500 iterations each (warmup=500), each
# started from a sampled initial plan.
plans_mcmc <- redist_mergesplit_parallel(map, 8000 + 500, warmup=500, chains=4,
                                         counties=county, init_plan="sample",
                                         init_name=FALSE)

# Per-draw compactness, county splits, and expected Democratic seats for every
# MCMC draw, one row per (chain, draw), as a plain tibble without the "plans"
# attribute.
plans_mcmc_s <- plans_mcmc %>%
    mutate(comp = comp_frac_kept(., map),
           spl = county_splits(map, county),
           dem = group_frac(map, NDV, NDV + NRV)) %>%
    group_by(chain, draw) %>%
    summarize(comp = head(comp, 1),
              spl = head(spl, 1),
              edem = sum(pnorm(dem, mean=0.5, sd=0.035)),
              .groups="drop") %>%
    as_tibble()
attr(plans_mcmc_s, "plans") <- NULL

# MCMC sample sizes: 10^2, 10^2.172, ... (step 0.172, up to 10^3.892), rounded.
N_seq <- round(10^seq(2, 4.0, by=0.172))

# create as-if-run-for-fewer results
# The chains were run once at full length. Results for a smaller sample size N
# are emulated by keeping each chain's first N draws and charging the runtime
# of 500 warmup + N iterations, using the mean per-iteration time of a
# 8500-iteration chain.
time_iter_mcmc <- mean(map_dbl(attr(plans_mcmc, "diagnostics"), \(d) d$runtime)) / 8500
res_mcmc <- map_dfr(N_seq, function(N) {
    # every draw occupies 18 rows (one per district)
    pl_head <- plans_mcmc %>%
        group_by(chain) %>%
        slice_head(n=18 * N) %>%
        ungroup()
    vi <- plans_diversity(pl_head, 150L)

    stats_head <- plans_mcmc_s %>%
        group_by(chain) %>%
        slice_head(n=N) %>%
        ungroup()

    # per-chain mean of one statistic, and its split-chain R-hat
    by_chain_mean <- function(x) tapply(x, stats_head$chain, mean)
    split_rhat <- function(x) redist:::diag_rhat(x, stats_head$chain, split=TRUE)

    ests_comp <- by_chain_mean(stats_head$comp)
    ests_spl <- by_chain_mean(stats_head$spl)
    ests_edem <- by_chain_mean(stats_head$edem)

    tibble(alg="mcmc", iter=N,
           time = (500 + N) * time_iter_mcmc,
           vi_q10 = quantile(vi, 0.1),
           vi_q90 = quantile(vi, 0.9),
           rhat_comp = split_rhat(stats_head$comp),
           rhat_spl = split_rhat(stats_head$spl),
           rhat_edem = split_rhat(stats_head$edem),
           est_comp = mean(ests_comp), sd_comp = sd(ests_comp),
           est_spl = mean(ests_spl), sd_spl = sd(ests_spl),
           est_edem = mean(ests_edem), sd_edem = sd(ests_edem))
})

# Stack the MCMC and SMC results and reshape to long format.
# - Standard errors are normalized by the SE in the final row, i.e. the
#   largest SMC run, because res_smc is bound last and ordered by sample size.
# - The county-split columns and the SMC-only n_eff are dropped.
# - After pivoting, `var` is one of rhat/est/sd/rse and `stat` is comp or edem.
res <- res_mcmc %>%
    bind_rows(res_smc) %>%
    mutate(rse_comp = sd_comp / tail(sd_comp, 1),
           rse_edem = sd_edem / tail(sd_edem, 1)) %>%
    # disabled alternative: mutate(across(starts_with("rhat"), ~ . - 1)) %>%
    select(-c(n_eff, ends_with("_spl"))) %>%
    pivot_longer(cols=rhat_comp:rse_edem, names_to=c("var", "stat"), names_sep="_")

##  efficiency comparison plots -----
# Shared x-axis breaks (sample sizes) for the efficiency panels.
xbrk <- c(100, 300, 1000, 3000, 8000)

# Panel (a): R-hat against sample size.
# - Colour/shape = algorithm; line type = statistic.
# - Reference line at 1; y-axis limited to [1, 2] with coord_cartesian().
p_rhat_iter <- res %>%
    filter(var == "rhat") %>%
    mutate(stat = c(comp="Fraction of edges removed",
                    edem="Expected Democratic seats")[stat]) %>%
    ggplot(aes(x=iter, y=value, color=toupper(alg), shape=toupper(alg),
               linetype=stat)) +
    geom_hline(yintercept=1.0, color="#00000077") +
    geom_line(linewidth=0.5) +
    geom_point(size=1.2) +
    scale_x_continuous(labels=comma, trans="log10", breaks=xbrk) +
    scale_color_manual(values=c(PAL[2], "#444444")) +
    coord_cartesian(ylim=c(1, 2.0)) +
    labs(title=expression("(a) "*hat(R)), x="Sample size", y=NULL,
         color="Algorithm", shape="Algorithm", linetype="Statistic") +
    theme_bw(base_family="Times", base_size=10)

# Panel (b): normalized standard errors (the rse_* values in `res`) against
# sample size.
# - Solid reference line at 0; dashed line at 1.
# - y-axis starts at 0.
p_se_iter <- res %>%
    filter(var == "rse") %>%
    mutate(stat = c(comp="Fraction of edges removed",
                    edem="Expected Democratic seats")[stat]) %>%
    ggplot(aes(x=iter, y=value, color=toupper(alg), shape=toupper(alg),
               linetype=stat)) +
    geom_hline(yintercept=0.0, color="#00000077") +
    geom_hline(yintercept=1.0, color="#00000055", linetype="longdash") +
    geom_line(linewidth=0.5) +
    geom_point(size=1.2) +
    scale_x_continuous(labels=comma, trans="log10", breaks=xbrk) +
    scale_y_continuous(limits=c(0, NA)) +
    scale_color_manual(values=c(PAL[2], "#444444")) +
    labs(title="(b) Normalized\n standard errors", x="Sample size", y=NULL,
         color="Algorithm", shape="Algorithm", linetype="Statistic") +
    theme_bw(base_family="Times", base_size=10)

# Panel (d): sampling time (time/60, labelled in minutes) against sample size,
# both axes on a log10 scale.
p_iter_time <- res %>%
    distinct(alg, time, iter) %>%
    ggplot(aes(x=iter, y=time/60, color=toupper(alg), shape=toupper(alg))) +
    geom_line(linewidth=0.5) +
    geom_point(size=1.2) +
    scale_x_continuous(labels=comma, trans="log10", breaks=xbrk) +
    scale_y_continuous(trans="log10") +
    scale_color_manual(values=c(PAL[2], "#444444")) +
    labs(title="(d) Sampling time", x="Sample size", y="Time (minutes)",
         color="Algorithm", shape="Algorithm") +
    theme_bw(base_family="Times", base_size=10)

# Panel (c): 10th percentile of the pairwise VI distances returned by
# plans_diversity(), against sample size (log10 axes).
p_vi_iter <- res %>%
    distinct(alg, iter, vi_q10) %>%
    ggplot(aes(x=iter, y=vi_q10, color=toupper(alg), shape=toupper(alg))) +
    geom_line(linewidth=0.5) +
    geom_point(size=1.2) +
    scale_x_continuous(labels=comma, trans="log10", breaks=xbrk) +
    scale_y_continuous(trans="log10") +
    scale_color_manual(values=c(PAL[2], "#444444")) +
    labs(title="(c) Combined sample\n diversity, 1st decile",
         x="Sample size", y="Pairwise VI distance",
         color="Algorithm", shape="Algorithm") +
    theme_bw(base_family="Times", base_size=10)

# Put the four efficiency panels in one row with a single collected legend at
# the bottom and tight margins. ggsave() is called without `plot=`, so it
# saves the most recently displayed plot (this patchwork).
p_rhat_iter + p_se_iter + p_vi_iter + p_iter_time +
    plot_layout(guides="collect", nrow=1) &
    theme(plot.background=element_blank(),
          plot.margin=unit(c(0, 1, 0, 0.2), "mm"),
          legend.position="bottom",
          legend.margin=margin(),
          legend.box.margin=margin())
ggsave(here("figures", "compare_pa.pdf"), height=2.75, width=8)

# Lag-1 autocorrelation of expected seats and compactness within each MCMC
# chain, averaged over chains.
acf_edem <- mean(vapply(split(plans_mcmc_s$edem, plans_mcmc_s$chain),
                        \(x) acf(x, plot=FALSE)$acf[2], numeric(1)))
acf_comp <- mean(vapply(split(plans_mcmc_s$comp, plans_mcmc_s$chain),
                        \(x) acf(x, plot=FALSE)$acf[2], numeric(1)))

# Remove the MCMC plans and the per-draw summary tables from the workspace.
rm(plans_mcmc, plans_mcmc_s, plans_smc_s)

##  partisan analysis plots -----
library(ggrepel)

# check r-hats
# Democratic share of every district in every draw of the multi-run SMC sample.
# Within each draw, districts are renumbered 1..18 from least to most
# Democratic, so "district k" means the k-th least Democratic district.
# The result is left grouped by draw.
pl_rhat <- plans_smc %>%
    mutate(dem = group_frac(map, NDV, NDV + NRV)) %>%
    as_tibble() %>%
    `attr<-`("plans", NULL) %>%
    filter(draw != "cd_orig") %>%
    group_by(draw) %>%
    arrange(dem, .by_group=TRUE) %>%
    mutate(district = seq_len(18))

# R-hat across runs (`chain`) of each ordered district's Democratic share.
rhats <- vapply(seq_len(18), function(d) {
    rows <- pl_rhat[pl_rhat$district == d, ]
    redist:::diag_rhat(rows$dem, rows$chain)
}, numeric(1))
print(rhats)

# Reduce the SMC sample to its first run (dropping `chain`) and add the six
# named reference plans stored as columns of `map`. Then attach compactness,
# county splits and district Democratic shares to every plan.
plans_smc <- local({
    ref_cols <- c("General Assembly" = "cd_orig",
                  "Court"            = "cd_court",
                  "Petitioner"       = "cd_league",
                  "Respondent"       = "cd_resp",
                  "Governor"         = "cd_gov",
                  "House Democrats"  = "cd_house")

    pl <- plans_smc %>%
        filter(chain == 1) %>%
        select(-chain) %>%
        subset_sampled()
    # NOTE(review): the order of these add_reference() calls presumably fixes
    # the `draw` levels that the plots below index via as.integer(draw) --
    # keep it unchanged.
    for (ref_name in names(ref_cols)) {
        pl <- add_reference(pl, as.integer(map[[ref_cols[[ref_name]]]]), ref_name)
    }

    pl %>%
        mutate(comp = comp_frac_kept(., map),
               spl = county_splits(map, county),
               dem = group_frac(map, NDV, NDV+NRV))
})

# One row per plan with its compactness and county-split count.
d_comp_spl <- plans_smc %>%
    as_tibble() %>%
    distinct(draw, comp, spl)

# (a) Compactness: density of the fraction of edges removed (1 - comp) over
# the SMC ensemble, with the six reference plans shown as labelled vertical
# lines.
# The references occupy `draw` levels 1-6 (they are selected below with
# as.integer(draw) <= 6), so the ensemble is every level above 6. A `> 2`
# cut-off would mix four of the reference plans into the ensemble density.
p1 <- filter(d_comp_spl, as.integer(draw) > 6) %>%
ggplot(aes(1-comp)) +
    geom_density(adjust=1.5, color=NA, fill="#666666") +
    # `linewidth` replaces the line-width use of `size`, deprecated in
    # ggplot2 3.4.0 (the file already uses `linewidth` elsewhere)
    geom_vline(aes(xintercept=1-comp, color=draw, lty=draw),
               data=filter(d_comp_spl, as.integer(draw) <= 6),
               linewidth=0.7) +
    geom_label_repel(aes(fill=draw, label=draw), y = 250,
                     data=filter(d_comp_spl, as.integer(draw) <= 6),
                     label.r=0.0, label.padding=0.1, 
                     fontface="bold", family="Times", size=2.5) +
    scale_x_continuous(labels=percent) +
    scale_y_continuous(expand=expansion(mult=c(0, 0.05))) +
    scale_color_manual(values=PAL[c(4, 3, 6, 5, 1, 2)]) +
    scale_fill_manual(values=PAL[c(4, 3, 6, 5, 1, 2)]) +
    labs(x="Fraction of edges removed", y="Density",
         title="(a) Compactness") +
    guides(color="none", lty="none", fill="none") +
    theme_bw(base_size=10, base_family="Times") +
    theme(legend.position=c(0.2, 0.8),
          legend.background=element_blank()) 

# (b) County splits: histogram of county splits over the SMC ensemble, with the
# six reference plans shown as labelled vertical lines.
# The references occupy `draw` levels 1-6 (selected below with
# as.integer(draw) <= 6), so the ensemble is every level above 6. A `> 2`
# cut-off would add four reference plans to the histogram counts.
p2 <- filter(d_comp_spl, as.integer(draw) > 6) %>%
ggplot(aes(spl)) +
    geom_histogram(binwidth=1, center=0, color=NA, fill="#666666") +
    # `linewidth` replaces the line-width use of `size`, deprecated in
    # ggplot2 3.4.0 (the file already uses `linewidth` elsewhere)
    geom_vline(aes(xintercept=spl, color=draw, lty=draw),
               data=filter(d_comp_spl, as.integer(draw) <= 6),
               linewidth=0.7) +
    geom_label_repel(aes(fill=draw, label=draw), y = 1100,
                     data=filter(d_comp_spl, as.integer(draw) <= 6),
                     label.r=0.0, label.padding=0.1, 
                     fontface="bold", family="Times", size=2.5) +
    scale_y_continuous(expand=expansion(mult=c(0, 0.05))) +
    scale_color_manual(values=PAL[c(4, 3, 6, 5, 1, 2)]) +
    scale_fill_manual(values=PAL[c(4, 3, 6, 5, 1, 2)]) +
    labs(x="County splits", y="Number of samples",
         title="(b) County splits") +
    guides(color="none", lty="none", fill="none") +
    theme_bw(base_size=10, base_family="Times") +
    theme(legend.position=c(0.2, 0.8),
          legend.background=element_blank()) 


# District Democratic shares for every plan except the Petitioner, Respondent,
# Governor and House Democrats references.
# - Districts are renumbered 1..18 within each plan, from least to most
#   Democratic.
# - The result is ungrouped.
d_dem <- plans_smc %>%
    filter(!(draw %in% c("Petitioner", "Respondent", "Governor", "House Democrats"))) %>%
    as_tibble() %>%
    select(draw, district, dem) %>%
    group_by(draw) %>%
    arrange(dem, .by_group=TRUE) %>%
    mutate(district = seq_len(18)) %>%
    ungroup()

# (c) Sorted district-level Democratic vote shares.
# - Boxplots come from the plans after the first two `draw` levels of d_dem.
# - Those first two levels are overlaid as points; NOTE(review): presumably
#   the retained General Assembly and Court plans -- confirm.
p3 <- filter(d_dem, as.integer(draw) > 2) %>%
ggplot(aes(district, dem, group=district)) +
    geom_hline(yintercept=0.5, color="#00000088") +
    # `linewidth` replaces the line-width use of `size` (deprecated in
    # ggplot2 3.4.0); `outlier.size` still sets the outlier point size
    geom_boxplot(linewidth=0.25, outlier.size=0.05) +
    geom_point(aes(fill=draw, shape=draw),
               data=filter(d_dem, as.integer(draw) <= 2),
               size=2.0, color="#00000088") +
    scale_y_continuous(labels=percent) +
    scale_fill_manual(values=PAL) +
    scale_shape_manual(values=c(21, 23)) +
    labs(x="District, from least to most Democratic",
         y="Democratic vote share",
         title="(c) District-level Democratic vote shares",
         fill="Plan", shape="Plan") +
    theme_bw(base_size=10, base_family="Times") +
    theme(legend.position=c(0.32, 0.8),
          legend.background=element_blank()) 

# Three-panel summary: compactness over county splits on the left, vote
# shares on the right.
((p1 / p2) | p3) + plot_layout(widths=c(2, 3))
ggsave(here("figures/pa_summary.pdf"), width=8.5, height=5)

# Table of how many districts with Democratic share > 50% each plan in d_dem
# has.
d_dem %>%
    group_by(draw) %>%
    summarize(n_dem_seats = sum(dem > 0.5)) %>%
    pull(n_dem_seats) %>%
    table()

# Fraction of plans (Court excluded) whose 9th-least-Democratic district has a
# lower Democratic share than that of the first plan in d_dem's order.
# NOTE(review): the first plan is presumably the General Assembly plan.
d_dem %>%
    filter(district == 9, draw != "Court") %>%
    pull(dem) %>%
    {mean(.[1] > .)} %>%
    as.character()


# r-hats for AoAS reviewer ------
# R-hats across the four independent SMC runs (`chain`), for each district
# ordered from least to most Democratic. For each district there is one R-hat
# for its Democratic share and one for the indicator of falling below each of
# its 5/25/50/75/95% quantiles.
#
# By this point `plans_smc` has been reduced to run 1 with `chain` dropped, so
# it no longer carries the run labels these R-hats need. `plans_smc_s` has
# also already been removed.
# `pl_rhat` (built above from the full multi-run sample) still has the run
# labels. It already numbers districts 1..18 by increasing Democratic share
# within each draw, which is what `number_by(dem)` was used for here.
pl <- ungroup(pl_rhat)

res_rhat <- pl |>
    group_by(district) |>
    summarize(rhat_mean = redist:::diag_rhat(dem, chain),
              rhat_q05 = redist:::diag_rhat(dem < quantile(dem, 0.05), chain),
              rhat_q25 = redist:::diag_rhat(dem < quantile(dem, 0.25), chain),
              # NOTE(review): the 1e-5 offset (median only) presumably guards
              # against ties at the median -- confirm.
              rhat_q50 = redist:::diag_rhat(dem < quantile(dem, 0.50)+1e-5, chain),
              rhat_q75 = redist:::diag_rhat(dem < quantile(dem, 0.75), chain),
              rhat_q95 = redist:::diag_rhat(dem < quantile(dem, 0.95), chain))

# R-hat of expected Democratic seats for the retained SMC sample. It is
# recomputed from the per-district shares using the same definition as the
# SMC sweep (sum of pnorm(dem, 0.5, 0.035) per draw).
pl |>
    group_by(chain, draw) |>
    summarize(edem = sum(pnorm(dem, 0.5, 0.035)), .groups="drop") |>
    with(redist:::diag_rhat(edem, chain))
